GSoC Blog 2 - Gen tutorial

Work in Progress Update

The last few weeks have been extremely fulfilling as I learnt more about Gen, ArviZ and the wider PPL ecosystem. My project is close to completion, and I will spend the next two weeks wrapping up and making any required changes to the code. Since I have been learning Gen for some time now, I realised it would be helpful to put a tutorial out there.

Some background about Gen

Gen is a platform for probabilistic modeling and inference. Gen includes a set of built-in languages for defining probabilistic models and a standard library for defining probabilistic inference algorithms, but is designed to be extended with an open-ended set of more specialized modeling languages and inference libraries.

A simple probabilistic program in Julia

Consider the following Julia code:

using Gen: uniform_discrete, bernoulli, categorical

function f(p)
    n = uniform_discrete(1, 10)
    if bernoulli(p)
        n *= 2
    end
    return categorical([i == n ? 0.5 : 0.5/19 for i=1:20])
end;

The function f calls three functions provided by Gen, each of which returns a random value, sampled from a certain probability distribution:

a. uniform_discrete(a, b) returns an integer uniformly sampled from the set {a, ..., b}.

b. bernoulli(p) returns true with probability p and false with probability 1-p.

c. categorical(probs) returns the integer i with probability probs[i] for i in the set {1, ..., length(probs)}.

These are three of the many probability distributions that are provided by Gen.

The function f first sets the initial value of n to a random value drawn from the set of integers {1, ..., 10}:

    n = uniform_discrete(1, 10)

Then, with probability p, it multiplies n by two:

    if bernoulli(p)
        n *= 2
    end

Then, it samples an integer in the set {1, …, 20}. With probability 0.5 the integer is n, and with probability 0.5 it is uniformly chosen from the remaining 19 integers. It returns this sampled integer:

    return categorical([i == n ? 0.5 : 0.5/19 for i=1:20])
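
As a quick sanity check, the weights passed to categorical form a valid probability vector: one entry of 0.5 plus nineteen entries of 0.5/19. A minimal sketch in plain Julia, with n = 7 picked arbitrarily for illustration:

```julia
# Build the weight vector used in f for an arbitrary, illustrative n.
n = 7
probs = [i == n ? 0.5 : 0.5/19 for i = 1:20]

# One entry carries probability 0.5; the other 19 share the remaining 0.5.
total = sum(probs)
println(total ≈ 1.0)
```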

If we run this function many times, we can see the probability distribution on its return values. The distribution depends on the argument p to the function:

using Plots

# Bin edges 0, 1, ..., 21, covering the possible return values 1 to 20.
bins = collect(0:21)

function plot_histogram(p)
    histogram([f(p) for _=1:100000], bins=bins, title="p=$p", label=nothing)
end

plot(map(plot_histogram, [0.1, 0.5, 0.9])...)
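
For a model this small, the histograms can be cross-checked against the exact distribution, obtained by enumerating every combination of the first two random choices. The sketch below is plain Julia and does not use Gen; the name exact_pmf is my own, introduced only for illustration:

```julia
# Exact distribution of the return value of f(p), obtained by enumerating
# every combination of the initial n and the bernoulli branch.
function exact_pmf(p)
    pmf = zeros(20)
    for initial_n in 1:10, do_branch in (false, true)
        # Probability of this (initial_n, do_branch) combination.
        weight = (1 / 10) * (do_branch ? p : 1 - p)
        n = do_branch ? 2 * initial_n : initial_n
        for k in 1:20
            # Same weights as the categorical call in f.
            pmf[k] += weight * (k == n ? 0.5 : 0.5 / 19)
        end
    end
    return pmf
end

pmf = exact_pmf(0.5)
println(sum(pmf) ≈ 1.0)
```

With 100000 samples, the bar for value k in the histogram should sit close to 100000 * exact_pmf(p)[k].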

Suppose we wanted to see what the distribution on return values would be if the initial value of n was 2. Because we don’t know what random values were sampled during a given execution, we can’t use simulations of f to answer this question. We would first have to modify f to return the initial value of n. Suppose we then wanted to ask more questions. We might need to modify f each time we have a new question, to make sure that the function returns the particular pieces of information about the execution that the question requires.

Note that if the function always returned the value of every random choice, these values would be sufficient to answer any question using executions of the function, because every state in the execution is deterministic given the random choices. We will call the record of all the random choices a trace. In order to store all the random choices in the trace, we need to come up with a unique name, or address, for each random choice.

Below, we implement the trace as a dictionary that maps addresses of random choices to their values. We use a unique Julia Symbol for each address:

function f_with_trace(p)
    trace = Dict()
    
    initial_n = uniform_discrete(1, 10)
    trace[:initial_n] = initial_n
    
    n = initial_n
    
    do_branch = bernoulli(p)
    trace[:do_branch] = do_branch
    
    if do_branch
        n *= 2
    end
    
    result = categorical([i == n ? 0.5 : 0.5/19 for i=1:20])
    trace[:result] = result
    
    return (result, trace)
end;

We run the function, and get the return value and the trace:

f_with_trace(0.3)

Output: (14, Dict{Any, Any}(:result => 14, :do_branch => false, :initial_n => 4))
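
With the trace recorded, the earlier question (what do the return values look like when the initial value of n was 2?) can be answered by filtering runs on their traces, without touching the function again. Here is a minimal sketch of that idea. The helper names results_where and final_n are my own, and the hand-written runs only stand in for real output of f_with_trace so that the snippet runs by itself:

```julia
# Keep the return values of the runs whose trace has `value` at `address`.
# Each run is a (result, trace) pair, the shape returned by f_with_trace.
function results_where(runs, address, value)
    return [result for (result, trace) in runs if trace[address] == value]
end

# Recompute the final n from a trace: once the random choices are known,
# the intermediate state of the function is deterministic.
function final_n(trace)
    n = trace[:initial_n]
    return trace[:do_branch] ? 2 * n : n
end

# Hand-written runs for illustration only; in practice they would come from
# something like [f_with_trace(0.3) for _ = 1:100000].
runs = [
    (14, Dict(:result => 14, :do_branch => false, :initial_n => 4)),
    (4, Dict(:result => 4, :do_branch => true, :initial_n => 2)),
    (2, Dict(:result => 2, :do_branch => false, :initial_n => 2)),
]

println(results_where(runs, :initial_n, 2))
println(final_n(runs[2][2]))
```

A histogram of results_where(runs, :initial_n, 2) over many real runs would then approximate the distribution we were asking about.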

Thank you for reading!